{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import print_function\n",
    "import argparse\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.optim.lr_scheduler import OneCycleLR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Argument:\n",
    "    \"\"\"Plain holder for the run settings read elsewhere in this notebook.\n",
    "\n",
    "    Every constructor argument is stored unchanged on an attribute of the same name.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        batch_size=256,\n",
    "        test_batch_size=1000,\n",
    "        epochs=2,\n",
    "        lr=4.0,\n",
    "        no_cuda=False,\n",
    "        save_model=False,\n",
    "    ):\n",
    "        # Slice sizes read by train() and test() respectively.\n",
    "        self.batch_size, self.test_batch_size = batch_size, test_batch_size\n",
    "        # Epoch count and learning rate read by main().\n",
    "        self.epochs, self.lr = epochs, lr\n",
    "        # Switches: skip CUDA even when available / save the state dict at the end of main().\n",
    "        self.no_cuda, self.save_model = no_cuda, save_model\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "args = Argument()\n",
    "\n",
    "# CUDA is selected only when `args` does not disable it and torch reports it as available.\n",
    "use_cuda = (not args.no_cuda) and torch.cuda.is_available()\n",
    "device = torch.device(\"cuda\") if use_cuda else torch.device(\"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
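    "#%%time\n",
    "# Disabled one-time setup step. Presumably `data_saver.save_data()` writes the `data/*.pt`\n",
    "# files that the next cell loads -- TODO confirm against the data_saver module.\n",
    "\n",
    "#import data_saver\n",
    "#data_saver.save_data()"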
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Wall time: 1.94 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "# Load the saved train/test tensors; map_location places them on `device` as they are read.\n",
    "# NOTE(review): torch.load unpickles its input, so these files must come from a trusted source.\n",
    "train_data = torch.load('data/train_data.pt', map_location=device)\n",
    "train_labels = torch.load('data/train_labels.pt', map_location=device)\n",
    "\n",
    "# Same for the held-out set, which is read by test().\n",
    "test_data = torch.load('data/test_data.pt', map_location=device)\n",
    "test_labels = torch.load('data/test_labels.pt', map_location=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Net(nn.Module):\n",
    "    \"\"\"Small CNN: two 3x3 conv + ReLU + 2x2 max-pool stages followed by two linear layers.\n",
    "\n",
    "    `forward` returns log-probabilities over 10 classes (log_softmax along dim 1).\n",
    "    `fc1` takes 1600 flattened features, which is what a single-channel 28x28 input\n",
    "    produces after the two conv/pool stages (64 channels * 5 * 5).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(1, 32, 3, 1)\n",
    "        self.conv2 = nn.Conv2d(32, 64, 3, 1)\n",
    "        # This layer is applied to the 2-D (batch, 128) output of fc1. nn.Dropout2d is\n",
    "        # channel-wise dropout meant for 3-D/4-D inputs and warns on 2-D input in recent\n",
    "        # PyTorch; nn.Dropout is the element-wise layer intended for this case.\n",
    "        self.dropout2 = nn.Dropout(0.5)\n",
    "        self.fc1 = nn.Linear(1600, 128)\n",
    "        self.fc2 = nn.Linear(128, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Map a (batch, 1, H, W) tensor to (batch, 10) log-probabilities.\"\"\"\n",
    "        x = F.max_pool2d(F.relu(self.conv1(x)), 2)\n",
    "        x = F.max_pool2d(F.relu(self.conv2(x)), 2)\n",
    "        # Keep the batch dimension, flatten everything else for the linear layers.\n",
    "        x = torch.flatten(x, 1)\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = self.dropout2(x)\n",
    "        return F.log_softmax(self.fc2(x), dim=1)\n",
    "\n",
    "def train(args, model, device, optimizer, scheduler):\n",
    "    \"\"\"Run one in-order pass over the module-level `train_data` / `train_labels`.\n",
    "\n",
    "    Consecutive slices of `args.batch_size` rows are fed to `model`; each slice gets one\n",
    "    optimizer step followed by one scheduler step. `device` is accepted but not used here.\n",
    "    \"\"\"\n",
    "    per_batch = args.batch_size\n",
    "    model.train()\n",
    "    # Ceiling division, so a final shorter slice is still visited.\n",
    "    total_batches = (len(train_data) - 1) // per_batch + 1\n",
    "    for lo in (per_batch * k for k in range(total_batches)):\n",
    "        hi = lo + per_batch\n",
    "        inputs, labels = train_data[lo:hi], train_labels[lo:hi]\n",
    "        optimizer.zero_grad()\n",
    "        batch_loss = F.nll_loss(model(inputs), labels)\n",
    "        batch_loss.backward()\n",
    "        optimizer.step()\n",
    "        scheduler.step()\n",
    "    \n",
    "\n",
    "\n",
    "def test(args, model, device, epoch):\n",
    "    \"\"\"Evaluate `model` on the module-level `test_data` / `test_labels` and print a summary.\n",
    "\n",
    "    Args:\n",
    "        args: Object exposing `test_batch_size`.\n",
    "        model: Module to evaluate; it is switched to eval mode.\n",
    "        device: Accepted for signature symmetry with train(); not used in the body.\n",
    "        epoch: Epoch number, used only in the printed summary.\n",
    "\n",
    "    Returns:\n",
    "        Accuracy over the whole test set, as a percentage.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    batch_size = args.test_batch_size\n",
    "    num_samples = len(test_data)\n",
    "    # Ceiling division (same formula as train()): floor division skipped a trailing partial\n",
    "    # batch while the totals below were still divided by the full test-set length.\n",
    "    num_batches = (num_samples - 1) // batch_size + 1\n",
    "    test_loss = 0\n",
    "    correct = 0\n",
    "    with torch.no_grad():\n",
    "        for i in range(num_batches):\n",
    "            data = test_data[batch_size*i:batch_size*(i+1)]\n",
    "            target = test_labels[batch_size*i:batch_size*(i+1)]\n",
    "            output = model(data)\n",
    "            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss\n",
    "            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability\n",
    "            correct += pred.eq(target.view_as(pred)).sum().item()\n",
    "\n",
    "    test_loss /= num_samples\n",
    "    accuracy = 100. * correct / num_samples\n",
    "\n",
    "    print('Epoch {} Test set: Average loss: {:.4f}, Accuracy: ({:.2f}%)\\n'.format(\n",
    "        epoch, test_loss, accuracy))\n",
    "    return accuracy\n",
    "\n",
    "\n",
    "def main():\n",
    "    \"\"\"Train a fresh `Net` for `args.epochs` epochs, running test() after each one.\n",
    "\n",
    "    Reads the module-level `args`, `device` and `train_data`. When `args.save_model` is\n",
    "    set, the final state dict is written to \"mnist_cnn.pt\".\n",
    "    \"\"\"\n",
    "    net = Net().to(device)\n",
    "    opt = optim.Adadelta(net.parameters(), lr=args.lr)\n",
    "\n",
    "    # train() takes one scheduler step per batch, so the schedule length per epoch is the\n",
    "    # (ceiling) number of batches.\n",
    "    batches_per_epoch = (len(train_data) - 1) // args.batch_size + 1\n",
    "    sched = OneCycleLR(\n",
    "        opt,\n",
    "        max_lr=args.lr,\n",
    "        epochs=args.epochs,\n",
    "        steps_per_epoch=batches_per_epoch,\n",
    "        cycle_momentum=False,\n",
    "    )\n",
    "\n",
    "    for done in range(args.epochs):\n",
    "        train(args, net, device, opt, sched)\n",
    "        test(args, net, device, done + 1)\n",
    "\n",
    "    if not args.save_model:\n",
    "        return\n",
    "    torch.save(net.state_dict(), \"mnist_cnn.pt\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1 Test set: Average loss: 0.0545, Accuracy: (98.18%)\n",
      "\n",
      "Epoch 2 Test set: Average loss: 0.0294, Accuracy: (98.94%)\n",
      "\n",
      "Epoch 1 Test set: Average loss: 0.0544, Accuracy: (98.28%)\n",
      "\n",
      "Epoch 2 Test set: Average loss: 0.0294, Accuracy: (98.99%)\n",
      "\n",
      "Epoch 1 Test set: Average loss: 0.0475, Accuracy: (98.41%)\n",
      "\n",
      "Epoch 2 Test set: Average loss: 0.0258, Accuracy: (99.16%)\n",
      "\n",
      "Epoch 1 Test set: Average loss: 0.0650, Accuracy: (97.85%)\n",
      "\n",
      "Epoch 2 Test set: Average loss: 0.0254, Accuracy: (99.13%)\n",
      "\n",
      "Epoch 1 Test set: Average loss: 0.0554, Accuracy: (98.33%)\n",
      "\n",
      "Epoch 2 Test set: Average loss: 0.0279, Accuracy: (99.05%)\n",
      "\n",
      "Epoch 1 Test set: Average loss: 0.0597, Accuracy: (98.15%)\n",
      "\n",
      "Epoch 2 Test set: Average loss: 0.0295, Accuracy: (99.00%)\n",
      "\n",
      "Epoch 1 Test set: Average loss: 0.0665, Accuracy: (98.08%)\n",
      "\n",
      "Epoch 2 Test set: Average loss: 0.0271, Accuracy: (99.05%)\n",
      "\n",
      "Epoch 1 Test set: Average loss: 0.0729, Accuracy: (97.59%)\n",
      "\n",
      "Epoch 2 Test set: Average loss: 0.0265, Accuracy: (99.10%)\n",
      "\n",
      "2.31 s ± 23.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
     ]
    }
   ],
   "source": [
    "%%timeit\n",
    "# Each timed call is a complete main() run: a new Net is built, trained and evaluated.\n",
    "main()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
